package aws

import (
	"context"
	"strings"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/service/dynamodb"
	"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
	"github.com/aws/smithy-go"
	"github.com/go-kit/log/level"
	"github.com/grafana/dskit/backoff"
	"github.com/grafana/dskit/instrument"
	"github.com/pkg/errors"
	"github.com/prometheus/client_golang/prometheus"
	"golang.org/x/time/rate"

	"github.com/grafana/loki/v3/pkg/storage/config"
	"github.com/grafana/loki/v3/pkg/storage/stores/series/index"
	"github.com/grafana/loki/v3/pkg/util/log"
)

const (
	// errCodeSlowDown is the API error code compared against
	// smithy.APIError.ErrorCode() in this file: callManager.backoffAndRetry
	// backs off and retries on it, and UpdateTable logs it as a warning
	// instead of returning it.
	errCodeSlowDown = "SlowDown"
)

// Pluggable auto-scaler implementation.
//
// dynamoTableClient holds at most one autoscale; when it is non-nil the
// client calls the matching hook from CreateTable, DescribeTable and
// UpdateTable.
type autoscale interface {
	// PostCreateTable is called by CreateTable after the table itself has
	// been created successfully.
	PostCreateTable(ctx context.Context, desc config.TableDesc) error
	// This whole interface is very similar to chunk.TableClient, but
	// DescribeTable needs to mutate desc, so it takes a pointer.
	DescribeTable(ctx context.Context, desc *config.TableDesc) error
	// UpdateTable is called by dynamoTableClient.UpdateTable before it
	// compares current against expected. expected is passed by pointer
	// (presumably so the implementation can adjust it — confirm against
	// the implementation).
	UpdateTable(ctx context.Context, current config.TableDesc, expected *config.TableDesc) error
}

// callManager bundles the rate limiting and retry settings applied to every
// DynamoDB call made through backoffAndRetry.
type callManager struct {
	// limiter is waited on once at the start of each backoffAndRetry call.
	// It may be nil (tests), in which case no rate limiting is applied.
	limiter *rate.Limiter
	// backoffConfig configures the backoff created for each backoffAndRetry call.
	backoffConfig backoff.Config
}

// dynamoTableClient manages DynamoDB tables (list, create, delete, describe,
// update) and is what NewDynamoDBTableClient returns as an index.TableClient.
type dynamoTableClient struct {
	// DynamoDB is the client used for all DynamoDB API calls.
	DynamoDB dynamoClient
	// callManager rate-limits and retries those calls.
	callManager callManager
	// autoscale is optional; it is nil unless a metrics URL was configured.
	autoscale autoscale
	// metrics supplies the request-duration collector passed to
	// instrument.CollectedRequest and the sink for recordDynamoError.
	metrics *dynamoDBMetrics
	// kmsKeyID, when non-empty, makes CreateTable enable KMS server-side
	// encryption with this key.
	kmsKeyID string
}

// NewDynamoDBTableClient makes a new DynamoTableClient.
//
// The DynamoDB client is built from cfg.DynamoDB.URL. API calls are rate
// limited to cfg.APILimit per second with a burst of one and retried
// according to cfg.BackoffConfig. An auto-scaler is attached only when
// cfg.Metrics.URL is set. Any error from building the DynamoDB client or the
// auto-scaler is returned as-is.
func NewDynamoDBTableClient(cfg DynamoDBConfig, reg prometheus.Registerer) (index.TableClient, error) {
	client, err := dynamoClientFromURL(cfg.DynamoDB.URL)
	if err != nil {
		return nil, err
	}

	// Auto-scaling is opt-in: without a metrics URL the field stays nil and
	// every autoscale hook is skipped.
	var scaler autoscale
	if cfg.Metrics.URL != "" {
		if scaler, err = newMetricsAutoScaling(cfg); err != nil {
			return nil, err
		}
	}

	return dynamoTableClient{
		DynamoDB: &client,
		callManager: callManager{
			limiter:       rate.NewLimiter(rate.Limit(cfg.APILimit), 1),
			backoffConfig: cfg.BackoffConfig,
		},
		autoscale: scaler,
		metrics:   newMetrics(reg),
		kmsKeyID:  cfg.KMSKeyID,
	}, nil
}

// Stop is a no-op.
func (d dynamoTableClient) Stop() {}

// backoffAndRetry runs fn under the client's callManager, which applies the
// rate limit and the retry policy.
func (d dynamoTableClient) backoffAndRetry(ctx context.Context, fn func(context.Context) error) error {
	manager := d.callManager
	return manager.backoffAndRetry(ctx, fn)
}

// backoffAndRetry calls fn until it succeeds, fails with a non-throttling
// error, or the backoff stops (in which case the backoff's own error is
// returned).
//
// The rate limiter, when present, is waited on once per call to
// backoffAndRetry, not once per attempt.
//
// An attempt is retried only when its error is a smithy.APIError whose code is
// errCodeSlowDown or "ThrottlingException"; any other error is returned
// immediately.
func (d callManager) backoffAndRetry(ctx context.Context, fn func(context.Context) error) error {
	// DynamoDB's documented code for throttled calls is ThrottlingException.
	// "SlowDown" alone (an S3 code) would leave this retry loop dead against
	// real DynamoDB; it is still accepted so existing producers of that code
	// keep their behaviour.
	const errCodeThrottling = "ThrottlingException"

	if d.limiter != nil { // Tests will have a nil limiter.
		// The error is dropped on purpose: rate limiting here is best-effort
		// and returning it would start failing calls that proceed today.
		// NOTE(review): a cancelled ctx is presumably reported by the backoff
		// below — confirm.
		_ = d.limiter.Wait(ctx)
	}

	retries := backoff.New(ctx, d.backoffConfig)
	for retries.Ongoing() {
		err := fn(ctx)
		if err == nil {
			return nil
		}

		var apiErr smithy.APIError
		if !errors.As(err, &apiErr) {
			return err
		}
		switch apiErr.ErrorCode() {
		case errCodeSlowDown, errCodeThrottling:
			// Throttled: fall through to the backoff below.
		default:
			return err
		}

		level.Warn(log.WithContext(ctx, log.Logger)).Log("msg", "got error, backing off and retrying", "err", err, "retry", retries.NumRetries())
		retries.Wait()
	}
	return retries.Err()
}

// ListTables returns the names of all tables visible to the client.
//
// A single ListTables call returns only one page of names, so the request is
// repeated with ExclusiveStartTableName set to the previous page's
// LastEvaluatedTableName until no further page is reported. The whole
// pagination runs inside one rate-limited, retried and instrumented request;
// a retry restarts from the first page.
//
// On error the returned slice is nil.
func (d dynamoTableClient) ListTables(ctx context.Context) ([]string, error) {
	var tableNames []string
	err := d.backoffAndRetry(ctx, func(ctx context.Context) error {
		return instrument.CollectedRequest(ctx, "DynamoDB.ListTablesPages", d.metrics.dynamoRequestDuration, instrument.ErrorCode, func(ctx context.Context) error {
			// Reset on every attempt so a retry cannot duplicate names
			// collected by a partially successful earlier attempt.
			tableNames = nil
			input := &dynamodb.ListTablesInput{}
			for {
				resp, err := d.DynamoDB.ListTables(ctx, input)
				if err != nil {
					return err
				}
				if resp == nil {
					return nil
				}
				tableNames = append(tableNames, resp.TableNames...)

				next := resp.LastEvaluatedTableName
				if next == nil || *next == "" {
					return nil
				}
				input.ExclusiveStartTableName = next
			}
		})
	})
	if err != nil {
		// The response is nil when the call failed; never dereference it.
		return nil, err
	}
	return tableNames, nil
}

// chunkTagsToDynamoDB converts table tags into the DynamoDB SDK's tag type.
// It returns nil when ts has no entries. The order of the result follows map
// iteration and is therefore unspecified.
func chunkTagsToDynamoDB(ts config.Tags) []types.Tag {
	if len(ts) == 0 {
		return nil
	}
	converted := make([]types.Tag, 0, len(ts))
	for name, value := range ts {
		tag := types.Tag{Key: aws.String(name), Value: aws.String(value)}
		converted = append(converted, tag)
	}
	return converted
}

// CreateTable creates the table described by desc and then, in order:
// runs the auto-scaler's PostCreateTable hook (when an auto-scaler is set) and
// applies desc.Tags to the new table (when there are any).
//
// The table is keyed by hashKey (string) and rangeKey (binary). Billing is
// pay-per-request when desc.UseOnDemandIOMode is set, otherwise provisioned
// with desc.ProvisionedRead / desc.ProvisionedWrite. KMS server-side
// encryption is enabled when the client has a KMS key ID.
//
// The first failing step aborts the remaining ones and its error is returned;
// tagging errors filtered by relevantError are ignored.
func (d dynamoTableClient) CreateTable(ctx context.Context, desc config.TableDesc) error {
	// Filled in by a successful create; needed to address the table when tagging.
	var arn *string

	create := func(ctx context.Context) error {
		req := &dynamodb.CreateTableInput{
			TableName: aws.String(desc.Name),
			AttributeDefinitions: []types.AttributeDefinition{
				{AttributeName: aws.String(hashKey), AttributeType: types.ScalarAttributeTypeS},
				{AttributeName: aws.String(rangeKey), AttributeType: types.ScalarAttributeTypeB},
			},
			KeySchema: []types.KeySchemaElement{
				{AttributeName: aws.String(hashKey), KeyType: types.KeyTypeHash},
				{AttributeName: aws.String(rangeKey), KeyType: types.KeyTypeRange},
			},
			BillingMode: types.BillingModePayPerRequest,
		}

		// Provisioned mode additionally needs the capacity figures.
		if !desc.UseOnDemandIOMode {
			req.BillingMode = types.BillingModeProvisioned
			req.ProvisionedThroughput = &types.ProvisionedThroughput{
				ReadCapacityUnits:  aws.Int64(desc.ProvisionedRead),
				WriteCapacityUnits: aws.Int64(desc.ProvisionedWrite),
			}
		}

		if d.kmsKeyID != "" {
			req.SSESpecification = &types.SSESpecification{
				Enabled:        aws.Bool(true),
				SSEType:        types.SSETypeKms,
				KMSMasterKeyId: aws.String(d.kmsKeyID),
			}
		}

		resp, err := d.DynamoDB.CreateTable(ctx, req)
		if err != nil {
			return err
		}
		if table := resp.TableDescription; table != nil {
			arn = table.TableArn
		}
		return nil
	}

	err := d.backoffAndRetry(ctx, func(ctx context.Context) error {
		return instrument.CollectedRequest(ctx, "DynamoDB.CreateTable", d.metrics.dynamoRequestDuration, instrument.ErrorCode, create)
	})
	if err != nil {
		return err
	}

	if d.autoscale != nil {
		if err := d.autoscale.PostCreateTable(ctx, desc); err != nil {
			return err
		}
	}

	tags := chunkTagsToDynamoDB(desc.Tags)
	if len(tags) == 0 {
		return nil
	}

	tagTable := func(ctx context.Context) error {
		_, err := d.DynamoDB.TagResource(ctx, &dynamodb.TagResourceInput{
			ResourceArn: arn,
			Tags:        tags,
		})
		if !relevantError(err) {
			return nil
		}
		return err
	}
	return d.backoffAndRetry(ctx, func(ctx context.Context) error {
		return instrument.CollectedRequest(ctx, "DynamoDB.TagResource", d.metrics.dynamoRequestDuration, instrument.ErrorCode, tagTable)
	})
}

// DeleteTable deletes the table called name. The call is rate limited,
// retried and instrumented like every other DynamoDB request in this client;
// the error of the final attempt, if any, is returned.
func (d dynamoTableClient) DeleteTable(ctx context.Context, name string) error {
	remove := func(ctx context.Context) error {
		_, err := d.DynamoDB.DeleteTable(ctx, &dynamodb.DeleteTableInput{TableName: aws.String(name)})
		return err
	}
	return d.backoffAndRetry(ctx, func(ctx context.Context) error {
		return instrument.CollectedRequest(ctx, "DynamoDB.DeleteTable", d.metrics.dynamoRequestDuration, instrument.ErrorCode, remove)
	})
}

// DescribeTable reports the current state of the table called name.
//
// It performs up to three steps, stopping at the first error:
//  1. DescribeTable: fills desc.Name, the provisioned read/write capacity and
//     the billing mode, and sets isActive when the table status is ACTIVE.
//  2. ListTagsOfResource: fills desc.Tags. Errors filtered out by
//     relevantError (DynamoDB Local has no tagging) leave desc.Tags empty.
//  3. The auto-scaler's DescribeTable hook, when an auto-scaler is set.
//
// On error, desc and isActive hold whatever the completed steps produced.
func (d dynamoTableClient) DescribeTable(ctx context.Context, name string) (desc config.TableDesc, isActive bool, err error) {
	var tableARN *string
	err = d.backoffAndRetry(ctx, func(ctx context.Context) error {
		return instrument.CollectedRequest(ctx, "DynamoDB.DescribeTable", d.metrics.dynamoRequestDuration, instrument.ErrorCode, func(ctx context.Context) error {
			out, err := d.DynamoDB.DescribeTable(ctx, &dynamodb.DescribeTableInput{
				TableName: aws.String(name),
			})
			if err != nil {
				return err
			}
			desc.Name = name
			if out.Table != nil {
				if provision := out.Table.ProvisionedThroughput; provision != nil {
					if provision.ReadCapacityUnits != nil {
						desc.ProvisionedRead = *provision.ReadCapacityUnits
					}
					if provision.WriteCapacityUnits != nil {
						desc.ProvisionedWrite = *provision.WriteCapacityUnits
					}
				}
				isActive = (out.Table.TableStatus == types.TableStatusActive)
				if out.Table.BillingModeSummary != nil {
					desc.UseOnDemandIOMode = out.Table.BillingModeSummary.BillingMode == types.BillingModePayPerRequest
				}
				tableARN = out.Table.TableArn
			}
			return nil
		})
	})
	if err != nil {
		return desc, isActive, err
	}

	err = d.backoffAndRetry(ctx, func(ctx context.Context) error {
		return instrument.CollectedRequest(ctx, "DynamoDB.ListTagsOfResource", d.metrics.dynamoRequestDuration, instrument.ErrorCode, func(ctx context.Context) error {
			out, err := d.DynamoDB.ListTagsOfResource(ctx, &dynamodb.ListTagsOfResourceInput{
				ResourceArn: tableARN,
			})
			if relevantError(err) {
				return err
			}

			// When the error was filtered out above there is no output to
			// read; treat that as "no tags" instead of dereferencing nil.
			var listed []types.Tag
			if out != nil {
				listed = out.Tags
			}
			desc.Tags = make(map[string]string, len(listed))
			for _, tag := range listed {
				if tag.Key == nil || tag.Value == nil {
					continue
				}
				desc.Tags[*tag.Key] = *tag.Value
			}
			return nil
		})
	})
	if err != nil {
		// Return here so the auto-scaler call below cannot overwrite (and
		// thereby hide) a tag-listing failure.
		return desc, isActive, err
	}

	if d.autoscale != nil {
		err = d.autoscale.DescribeTable(ctx, &desc)
	}
	return desc, isActive, err
}

// relevantError reports whether err should be surfaced to the caller.
// It is false for a nil error and for any error whose message contains the
// DynamoDB Local "tagging not supported" text (currently only relevant in
// integration tests); every other error is relevant.
func relevantError(err error) bool {
	const taggingUnsupported = "Tagging is not currently supported in DynamoDB Local."
	return err != nil && !strings.Contains(err.Error(), taggingUnsupported)
}

// UpdateTable brings the table from its current description to the expected one.
//
// Steps, in order:
//  1. The auto-scaler's UpdateTable hook (when set); its error aborts the update.
//  2. At most one UpdateTable API call:
//     - expected is provisioned and read or write capacity differs: set the
//     new throughput, also switching the billing mode to provisioned when the
//     mode differs;
//     - otherwise, expected is on-demand and the mode differs: switch the
//     billing mode to pay-per-request.
//  3. When the tags differ: look up the table ARN and apply expected.Tags.
//
// A failed UpdateTable call is always recorded via recordDynamoError. If it is
// an API error with code errCodeSlowDown or "LimitExceededException" it is
// only logged ("update limit exceeded") and the function carries on with the
// tags; any other error is returned.
func (d dynamoTableClient) UpdateTable(ctx context.Context, current, expected config.TableDesc) error {
	// DynamoDB's documented code for exceeding table-operation limits is
	// LimitExceededException. errCodeSlowDown is still accepted so anything
	// already producing it keeps being tolerated.
	const errCodeLimitExceeded = "LimitExceededException"

	if d.autoscale != nil {
		if err := d.autoscale.UpdateTable(ctx, current, &expected); err != nil {
			return err
		}
	}
	level.Debug(log.Logger).Log("msg", "Updating Table",
		"expectedWrite", expected.ProvisionedWrite,
		"currentWrite", current.ProvisionedWrite,
		"expectedRead", expected.ProvisionedRead,
		"currentRead", current.ProvisionedRead,
		"expectedOnDemandMode", expected.UseOnDemandIOMode,
		"currentOnDemandMode", current.UseOnDemandIOMode)

	// settle records a failed UpdateTable call and decides whether it is fatal:
	// limit errors are logged and swallowed, everything else is handed back.
	settle := func(err error) error {
		if err == nil {
			return nil
		}
		recordDynamoError(expected.Name, err, "DynamoDB.UpdateTable", d.metrics)
		var apiErr smithy.APIError
		if errors.As(err, &apiErr) {
			switch apiErr.ErrorCode() {
			case errCodeSlowDown, errCodeLimitExceeded:
				level.Warn(log.Logger).Log("msg", "update limit exceeded", "err", err)
				return nil
			}
		}
		return err
	}

	throughputChanged := current.ProvisionedRead != expected.ProvisionedRead ||
		current.ProvisionedWrite != expected.ProvisionedWrite
	billingModeChanged := current.UseOnDemandIOMode != expected.UseOnDemandIOMode

	switch {
	case throughputChanged && !expected.UseOnDemandIOMode:
		level.Info(log.Logger).Log("msg", "updating provisioned throughput on table", "table", expected.Name, "old_read", current.ProvisionedRead, "old_write", current.ProvisionedWrite, "new_read", expected.ProvisionedRead, "new_write", expected.ProvisionedWrite)
		err := d.backoffAndRetry(ctx, func(ctx context.Context) error {
			return instrument.CollectedRequest(ctx, "DynamoDB.UpdateTable", d.metrics.dynamoRequestDuration, instrument.ErrorCode, func(ctx context.Context) error {
				updateTableInput := &dynamodb.UpdateTableInput{
					TableName: aws.String(expected.Name),
					ProvisionedThroughput: &types.ProvisionedThroughput{
						ReadCapacityUnits:  aws.Int64(expected.ProvisionedRead),
						WriteCapacityUnits: aws.Int64(expected.ProvisionedWrite),
					},
				}
				// we need this to be a separate check for the billing mode, as aws returns
				// an error if we set a table to the billing mode it is currently on.
				if billingModeChanged {
					level.Info(log.Logger).Log("msg", "updating billing mode on table", "table", expected.Name, "old_mode", current.UseOnDemandIOMode, "new_mode", expected.UseOnDemandIOMode)
					updateTableInput.BillingMode = types.BillingModeProvisioned
				}

				_, err := d.DynamoDB.UpdateTable(ctx, updateTableInput)
				return err
			})
		})
		if err := settle(err); err != nil {
			return err
		}

	case expected.UseOnDemandIOMode && billingModeChanged:
		// Enabling on-demand mode has its own branch to keep it free of the
		// settings that only apply to provisioned mode.
		err := d.backoffAndRetry(ctx, func(ctx context.Context) error {
			return instrument.CollectedRequest(ctx, "DynamoDB.UpdateTable", d.metrics.dynamoRequestDuration, instrument.ErrorCode, func(ctx context.Context) error {
				level.Info(log.Logger).Log("msg", "updating billing mode on table", "table", expected.Name, "old_mode", current.UseOnDemandIOMode, "new_mode", expected.UseOnDemandIOMode)
				updateTableInput := &dynamodb.UpdateTableInput{TableName: aws.String(expected.Name), BillingMode: types.BillingModePayPerRequest}
				_, err := d.DynamoDB.UpdateTable(ctx, updateTableInput)
				return err
			})
		})
		if err := settle(err); err != nil {
			return err
		}
	}

	if current.Tags.Equals(expected.Tags) {
		return nil
	}

	// Tagging addresses the table by ARN, which neither description carries,
	// so look it up first.
	var tableARN *string
	if err := d.backoffAndRetry(ctx, func(ctx context.Context) error {
		return instrument.CollectedRequest(ctx, "DynamoDB.DescribeTable", d.metrics.dynamoRequestDuration, instrument.ErrorCode, func(ctx context.Context) error {
			out, err := d.DynamoDB.DescribeTable(ctx, &dynamodb.DescribeTableInput{
				TableName: aws.String(expected.Name),
			})
			if err != nil {
				return err
			}
			if out.Table != nil {
				tableARN = out.Table.TableArn
			}
			return nil
		})
	}); err != nil {
		return err
	}

	// NOTE(review): this only adds or overwrites tags; tags present on the
	// table but absent from expected.Tags are not removed here.
	return d.backoffAndRetry(ctx, func(ctx context.Context) error {
		return instrument.CollectedRequest(ctx, "DynamoDB.TagResource", d.metrics.dynamoRequestDuration, instrument.ErrorCode, func(ctx context.Context) error {
			_, err := d.DynamoDB.TagResource(ctx, &dynamodb.TagResourceInput{
				ResourceArn: tableARN,
				Tags:        chunkTagsToDynamoDB(expected.Tags),
			})
			if relevantError(err) {
				return errors.Wrap(err, "applying tags")
			}
			return nil
		})
	})
}
